import pandas as pd
from collections import defaultdict
from math import sqrt

def generate_full_magtool_output(coa_county_df, coa_state_df, coa_national_df, species_df,
                                 output_location, chem_name, master_species_list_location,
                                 include_flags, is_row_crop):
    """Build every MagTool output table and write them to one workbook.

    Parameters
    ----------
    coa_county_df, coa_state_df, coa_national_df :
        Crop acreage tables at county, state and national level.
    species_df :
        County-level species area table.
    output_location :
        Destination handed to ``save_magtool_output``.
    chem_name :
        Prefix used for the buffer ("AA_") column names.
    master_species_list_location :
        Location of the master species list CSV.
    include_flags :
        When truthy, imputation flag columns are added to the direct table.
    is_row_crop :
        Lookup indexed by crop name; selects the field size used when
        buffering each crop.

    Returns
    -------
    None. The result is the file written by ``save_magtool_output``.
    """
    # Master species list; EntityID is forced to str so it can be merged on
    master_species = import_master_species_info(master_species_list_location)
    master_species["EntityID"] = master_species["EntityID"].astype(str)

    # Every table builder takes the same five leading inputs
    shared_inputs = (coa_county_df, coa_state_df, coa_national_df, species_df, master_species)

    # Direct overlap table, optionally extended with imputation flag columns
    direct_overlap = generate_direct_overlap_dataframe(*shared_inputs)
    if include_flags:
        direct_overlap = generate_imputation_flag_columns(direct_overlap, coa_county_df, species_df)

    # National and by-state direct + buffer tables
    national_buffer, state_buffer = generate_buffer_dataframes(*shared_inputs, chem_name)

    # Per-crop buffer tables: cumulative totals ...
    by_use_sum = generate_overlap_by_use_buffer_sum(*shared_inputs, is_row_crop)

    # ... and per-interval increments
    by_use_increment = generate_overlap_by_use_buffer_increment(*shared_inputs, is_row_crop)

    # Write all sheets
    save_magtool_output(direct_overlap, national_buffer, state_buffer,
                        by_use_sum, by_use_increment, output_location)
    
def import_master_species_info(master_species_list_location):
    """Read the master species list, keeping only the output columns.

    The file is read with ``keep_default_na=False`` so that no cell is
    converted to NaN by pandas' default NA markers. The returned frame has
    its columns in the order listed below, regardless of their order in
    the file.
    """
    # Columns needed for output, in the order they should appear
    output_columns = [
        'EntityID', 'Common Name', 'Scientific Name',
        'Status (Updated November 2020)', 'Population ', 'Family', 'SPCODE',
        'VIPCODE', 'Lead Agency', 'Country', 'WoE Summary Group',
        'Species Group', 'Designated Critical Habitat',
        'Source of Species Effects Determination',
        'Source of Critical Habitat Effects Determination',
        'CONUS/NL48- Species', 'CONUS/NL48- Critical Habitat',
        'Downstream Transport Species',
        'Wetland species for 1500 m consideration (quantitative only)',
        'Region',
    ]

    # read_csv keeps the file's own column order, so reorder afterwards
    raw_df = pd.read_csv(master_species_list_location, usecols=output_columns, keep_default_na=False)

    return raw_df[output_columns]

def save_magtool_output( direct_df, total_overlap_direct_and_buffer, by_state_overlap_and_buffer, overlap_by_use_buffer_sum, overlap_by_use_buffer_increment, output_location ):
    """Write the five output tables to an Excel workbook.

    Before writing, each frame's ``EntityID`` column is rewritten in place so
    that IDs made up only of digits become ints and every other ID is kept as
    a string. ``direct_df`` is written twice, under two sheet names.

    Parameters
    ----------
    direct_df, total_overlap_direct_and_buffer, by_state_overlap_and_buffer,
    overlap_by_use_buffer_sum, overlap_by_use_buffer_increment :
        DataFrames that each contain an ``EntityID`` column.
    output_location :
        Target passed to ``pd.ExcelWriter``.
    """
    def _restore_integer_entity_ids(frame):
        # Cast to str first so the digit check works whatever dtype the column
        # arrives with (all five frames are treated the same way).
        entity_ids = frame['EntityID'].astype(str)
        # isdecimal() is exactly the set of strings int() accepts as digits
        frame['EntityID'] = [int(eid) if eid.isdecimal() else eid for eid in entity_ids]

    # Change numeric entityIDs to ints
    for frame in (direct_df,
                  total_overlap_direct_and_buffer,
                  by_state_overlap_and_buffer,
                  overlap_by_use_buffer_sum,
                  overlap_by_use_buffer_increment):
        _restore_integer_entity_ids(frame)

    # Sheet name -> frame, in the order the sheets are written
    sheets = [
        ("Overlap by use direct", direct_df),
        ("Alt overlap by use direct", direct_df),
        ("Total overlap direct and buffer", total_overlap_direct_and_buffer),
        ("By state overlap and buffer", by_state_overlap_and_buffer),
        ("Overlap by use buffer sum", overlap_by_use_buffer_sum),
        ("Overlap by use buffer increm", overlap_by_use_buffer_increment),
    ]

    # Write the dataframes to excel sheets
    with pd.ExcelWriter(output_location) as writer:
        for sheet_name, frame in sheets:
            frame.to_excel(writer, sheet_name = sheet_name, index = False)
        

def generate_direct_overlap_dataframe(coa_county_df, coa_state_df, coa_national_df,
                                      species_df, ms_info):
    """Build the direct (no buffer) percent-overlap table per species and crop.

    County crop areas are scaled down where they exceed the species area,
    summed to state level and capped, summed to national level and capped,
    then converted to percentages. Crop columns are suffixed with "_0" and
    the result is left-merged onto ``ms_info``; species with no match get 0.
    """
    # Work on copies so the caller's frames are never modified
    coa_county_df, coa_state_df, coa_national_df, species_df = (
        frame.copy() for frame in (coa_county_df, coa_state_df, coa_national_df, species_df)
    )

    crop_names = extract_unique_crop_list(coa_county_df)
    county_species_acres = generate_species_dict(species_df)

    # County level: crop areas per species/county, then the redundancy adjustment
    county_acres = generate_county_level_crop_area_table(coa_county_df, species_df, crop_names)
    county_acres = apply_county_species_area_adjustment(county_acres, crop_names, county_species_acres)

    # State level: roll up, then cap
    state_acres = roll_up_county_area_table_to_state(county_acres, species_df)
    state_acres = apply_state_crop_area_cap(state_acres, coa_state_df, crop_names)

    # National level: roll up, then cap
    national_acres = roll_up_state_area_table_to_national(state_acres, species_df)
    national_acres = apply_national_crop_area_cap(national_acres, coa_national_df, crop_names)

    # Areas -> percentages
    national_pct = get_national_overlap_table(national_acres, species_df)

    # Every column except the key gets a "_0" suffix
    suffixed = {col: col + "_0" for col in national_pct if col != "Entity_ID"}
    national_pct = national_pct.rename(columns=suffixed)

    # Merge onto the master species list using string keys on both sides
    national_pct["Entity_ID"] = national_pct["Entity_ID"].astype(str)
    merged = ms_info.merge(national_pct, how="left", left_on="EntityID", right_on="Entity_ID")
    merged = merged.drop(columns="Entity_ID")

    # Species absent from the overlap table get 0 rather than NaN
    overlap_columns = list(suffixed.values())
    merged[overlap_columns] = merged[overlap_columns].fillna(value=0)

    return merged
    
def generate_buffer_dataframes(coa_county_df, coa_state_df, coa_national_df,
                                      species_df, ms_info, chem_name):
    """Build the national and by-state direct + buffer overlap tables.

    All crops are summed into one "<chem_name> AA_0" column per county and
    species, capped at the county species area. That area is then expanded
    for each buffer distance, capped again, and stored as the increment over
    the previous distance (30 m steps up to 810 m). Three extra
    "<chem_name> AA_<d>_total" columns hold the capped expansion minus the
    direct area for 305, 792 and 1500 m.

    Returns
    -------
    list
        ``[total_overlap_direct_and_buffer, by_state_overlap_and_buffer]``.
    """
    # Work on copies so the caller's frames are never modified
    coa_county_df, coa_state_df, coa_national_df, species_df = (
        frame.copy() for frame in (coa_county_df, coa_state_df, coa_national_df, species_df)
    )

    crop_names = extract_unique_crop_list(coa_county_df)
    county_species_acres = generate_species_dict(species_df)
    county_df = generate_county_level_crop_area_table(coa_county_df, species_df, crop_names)

    # NOTE(review): presumably a field size in square metres; 4046.86 below
    # looks like the m2-per-acre factor -- confirm both against the method doc.
    field_area = 101171.5
    column_prefix = chem_name + " AA_"
    direct_col = column_prefix + str(0)

    # All-crop direct area, capped at the county species area
    county_df[direct_col] = county_df[crop_names].sum(axis=1)
    county_df["Species_Area"] = county_df[["GEOID", "Entity_ID"]].apply(
        lambda row: county_species_acres[(row.GEOID, row.Entity_ID)], axis=1)
    species_area = county_df["Species_Area"]
    county_df[direct_col] = county_df[direct_col].mask(county_df[direct_col] > species_area, species_area)

    def capped_footprint(buffer_m):
        # Direct area expanded for a buffer of buffer_m, capped at species area
        footprint = (sqrt(field_area) + buffer_m * 2) ** 2
        expanded = county_df[direct_col] * 4046.86 / field_area * footprint / 4046.86
        return expanded.mask(expanded > species_area, species_area)

    # Cumulative direct + drift area for every buffer interval
    buffer_steps = list(range(30, 840, 30))
    cumulative = {0: county_df[direct_col]}
    for buffer_m in buffer_steps:
        cumulative[buffer_m] = capped_footprint(buffer_m)

    # Store each interval as the increase over the previous interval
    for buffer_m in buffer_steps:
        county_df[column_prefix + str(buffer_m)] = cumulative[buffer_m] - cumulative[buffer_m - 30]

    # Total buffer-only area for the three summary distances
    for buffer_m in [305, 792, 1500]:
        county_df[column_prefix + str(buffer_m) + "_total"] = capped_footprint(buffer_m) - county_df[direct_col]

    # The individual crop columns are no longer needed
    county_df = county_df.drop(columns=crop_names)

    # Roll up to state, then to national
    state_df = roll_up_county_area_table_to_state(county_df, species_df)
    national_df = roll_up_state_area_table_to_national(state_df, species_df)

    by_state_overlap_and_buffer = generate_state_overlap_output_df(
        state_df, national_df, species_df, chem_name, ms_info)
    total_overlap_direct_and_buffer = generate_national_direct_and_buffer_output_df(
        national_df, species_df, ms_info)

    return [total_overlap_direct_and_buffer, by_state_overlap_and_buffer]
    
def generate_state_overlap_output_df(state_overlap_area_df, national_overlap_area_df, species_df, chem_name, ms_info):
    """Shape the state-level buffer areas into the by-state output layout.

    Parameters
    ----------
    state_overlap_area_df :
        State/species table holding the "<chem_name> AA_*" columns, the three
        "_total" columns and "Species_Area". Not modified.
    national_overlap_area_df, ms_info :
        Accepted for interface compatibility; not read and not modified.
    species_df :
        Source of the national species acres and the State -> STATEFP lookup.
    chem_name :
        Prefix of the "AA_" columns.

    Returns
    -------
    DataFrame with columns ``ACRES (Total)``, ``EntityID``, ``STATEFP``,
    ``STATE`` followed by one column per buffer distance (named by the
    integer distance), sorted by ``EntityID`` as a string.
    """
    # True national species acres per EntityID
    national_species_acres = generate_national_species_acres_dict(species_df)

    # Get state FP lookup from species_df (unknown states map to 0)
    state_fp = defaultdict(int,species_df.set_index("State").STATEFP.to_dict())

    # Drop the total overlap and species columns; drop() returns a new frame,
    # so the caller's table is left untouched
    total_columns = [chem_name+" AA_"+str(i)+"_total" for i in [305,792,1500]]
    out_df = state_overlap_area_df.drop(columns = total_columns+["Species_Area"])

    # Insert the additional columns
    out_df["ACRES (Total)"] = out_df["Entity_ID"].apply(lambda x: national_species_acres[x])
    out_df["STATEFP"] = out_df["State"].apply(lambda x: state_fp[x])

    # Rename the overlap acreage columns to their integer buffer distance
    aa_rename_map = {i: int(i.split("_")[-1]) for i in out_df.columns if "AA_" in i}

    # Rename Entity_ID and State columns in the same pass
    rename_map = dict(aa_rename_map)
    rename_map.update({"Entity_ID": "EntityID", "State": "STATE"})
    out_df = out_df.rename(columns = rename_map)

    # Reorganize columns to match the template; copy so the assignment below
    # targets a standalone frame instead of a slice
    out_df = out_df[["ACRES (Total)", "EntityID", "STATEFP", "STATE"]+list(aa_rename_map.values())].copy()

    # Sort by EntityID so species are grouped together
    out_df['EntityID'] = out_df['EntityID'].astype(str)
    out_df = out_df.sort_values(by = 'EntityID')

    return out_df
    
def generate_national_direct_and_buffer_output_df(national_overlap_area_df, species_df, ms_info):
    """Convert national "AA_" areas to percent of species area and merge with ``ms_info``.

    The input table is copied first, so calling this more than once on the
    same frame always starts from the original areas.

    Parameters
    ----------
    national_overlap_area_df :
        National table with an ``Entity_ID`` column and "AA_" area columns.
    species_df :
        Source of the national species acres used as the denominator.
    ms_info :
        Master species table with an ``EntityID`` column.

    Returns
    -------
    ``ms_info`` left-merged with the percentage columns; species with no
    match get 0 in every "AA_" column.
    """
    # Work on a copy so the caller's table keeps its area values
    pct_df = national_overlap_area_df.copy()

    # Get a list of AA columns
    aa_cols = [c for c in pct_df if "AA_" in c]

    # True national species acres for each row
    national_species_acres = generate_national_species_acres_dict(species_df)
    species_area = pct_df["Entity_ID"].apply(lambda x: national_species_acres[x])

    # Divide all of the AA columns by species acres to get percent overlap
    for c in aa_cols:
        pct_df[c] = pct_df[c]/species_area*100

    # The rolled-up Species_Area column (if present) is not part of the output
    pct_df = pct_df.drop(columns = ["Species_Area"], errors = "ignore")

    # Merge with master species list on string keys
    pct_df["Entity_ID"] = pct_df["Entity_ID"].astype(str)
    merged_df = ms_info.merge(pct_df,how = "left",left_on = "EntityID", right_on="Entity_ID")

    # Get rid of the redundant EntityID
    merged_df = merged_df.drop(columns = ["Entity_ID"])

    # Fill entries for species without overlap data with zeros
    merged_df[aa_cols] = merged_df[aa_cols].fillna(value=0)

    return merged_df

def generate_overlap_by_use_buffer_sum(coa_county_df, coa_state_df, coa_national_df, species_df, ms_info, is_row_crop):
    """Build the per-crop cumulative buffer overlap table (percent of species area).

    For every crop and buffer distance the county crop area is expanded,
    capped at the county species area, summed straight to national level
    (no state or national caps) and divided by the national species acres.
    The result is left-merged onto the EntityID / name columns of
    ``ms_info``; species with no match get 0.
    """
    # Work on copies so the caller's frames are never modified
    coa_county_df, coa_state_df, coa_national_df, species_df = (
        frame.copy() for frame in (coa_county_df, coa_state_df, coa_national_df, species_df)
    )

    crop_names = extract_unique_crop_list(coa_county_df)
    county_species_acres = generate_species_dict(species_df)
    county_df = generate_county_level_crop_area_table(coa_county_df, species_df, crop_names)

    # County species area for each row
    county_df["Species_Area"] = county_df[["GEOID", "Entity_ID"]].apply(
        lambda row: county_species_acres[(row.GEOID, row.Entity_ID)], axis=1)
    species_area = county_df["Species_Area"]

    buffer_distances = list(range(0, 840, 30)) + [305, 792, 1500]
    pct_columns = []
    for crop in crop_names:
        # NOTE(review): presumably field sizes in square metres, with 4046.86
        # the m2-per-acre factor -- confirm against the method documentation.
        field_area = 2023430 if is_row_crop[crop] else 101171.5
        for distance_m in buffer_distances:
            column = crop + "_" + str(distance_m)
            pct_columns.append(column)
            footprint = (sqrt(field_area) + distance_m * 2) ** 2
            expanded = county_df[crop] * 4046.86 / field_area * footprint / 4046.86
            # Cap at the county species area
            county_df[column] = expanded.mask(expanded > species_area, species_area)

    # Roll up directly to national with no capping
    county_df = county_df.drop(columns=["GEOID"] + crop_names)
    national_df = county_df.groupby("Entity_ID").sum().reset_index()

    # Replace the summed species area with the true national species acres
    national_species_acres = generate_national_species_acres_dict(species_df)
    national_df["Species_Area"] = [national_species_acres[eid] for eid in national_df["Entity_ID"]]

    # Convert acreage into percentages
    for column in pct_columns:
        national_df[column] = national_df[column] / national_df["Species_Area"] * 100

    national_df = national_df[["Entity_ID"] + pct_columns]

    species_names = ms_info[["EntityID", "Common Name", "Scientific Name"]]
    out_df = species_names.merge(national_df, how="left", left_on="EntityID", right_on="Entity_ID")
    out_df = out_df.drop(columns=["Entity_ID"])

    # Species absent from the overlap table get 0 rather than NaN
    out_df[pct_columns] = out_df[pct_columns].fillna(value=0)

    return out_df
    
def generate_overlap_by_use_buffer_increment(coa_county_df, coa_state_df, coa_national_df, species_df, ms_info, is_row_crop):
    """Build the per-crop incremental buffer overlap table (percent of species area).

    For every crop the county crop area is expanded for each 30 m buffer
    step (0 to 810 m) and capped at the county species area. The "_0" column
    keeps the capped direct area; every later column holds the increase over
    the previous step. Values are summed straight to national level (no
    state or national caps), divided by the national species acres and
    left-merged onto the EntityID / name columns of ``ms_info``.
    """
    # Work on copies so the caller's frames are never modified
    coa_county_df, coa_state_df, coa_national_df, species_df = (
        frame.copy() for frame in (coa_county_df, coa_state_df, coa_national_df, species_df)
    )

    crop_names = extract_unique_crop_list(coa_county_df)
    county_species_acres = generate_species_dict(species_df)
    county_df = generate_county_level_crop_area_table(coa_county_df, species_df, crop_names)

    # County species area for each row
    county_df["Species_Area"] = county_df[["GEOID", "Entity_ID"]].apply(
        lambda row: county_species_acres[(row.GEOID, row.Entity_ID)], axis=1)
    species_area = county_df["Species_Area"]

    buffer_distances = list(range(0, 840, 30))
    pct_columns = []
    for crop in crop_names:
        # NOTE(review): presumably field sizes in square metres, with 4046.86
        # the m2-per-acre factor -- confirm against the method documentation.
        field_area = 2023430 if is_row_crop[crop] else 101171.5
        previous_capped = None
        for distance_m in buffer_distances:
            column = crop + "_" + str(distance_m)
            pct_columns.append(column)
            footprint = (sqrt(field_area) + distance_m * 2) ** 2
            expanded = county_df[crop] * 4046.86 / field_area * footprint / 4046.86
            capped = expanded.mask(expanded > species_area, species_area)
            if previous_capped is None:
                # First distance: keep the capped area itself
                county_df[column] = capped
            else:
                # Later distances: marginal increase over the previous step
                county_df[column] = capped - previous_capped
            previous_capped = capped

    # Roll up directly to national with no capping
    county_df = county_df.drop(columns=["GEOID"] + crop_names)
    national_df = county_df.groupby("Entity_ID").sum().reset_index()

    # Replace the summed species area with the true national species acres
    national_species_acres = generate_national_species_acres_dict(species_df)
    national_df["Species_Area"] = [national_species_acres[eid] for eid in national_df["Entity_ID"]]

    # Convert acreage into percentages
    for column in pct_columns:
        national_df[column] = national_df[column] / national_df["Species_Area"] * 100

    national_df = national_df[["Entity_ID"] + pct_columns]

    species_names = ms_info[["EntityID", "Common Name", "Scientific Name"]]
    out_df = species_names.merge(national_df, how="left", left_on="EntityID", right_on="Entity_ID")
    out_df = out_df.drop(columns=["Entity_ID"])

    # Species absent from the overlap table get 0 rather than NaN
    out_df[pct_columns] = out_df[pct_columns].fillna(value=0)

    return out_df
            
    
    
    

def generate_species_dict(species_df):
    """Return species acres keyed by ``(GEOID, EntityID)``.

    The result is a ``defaultdict(int)``, so unknown keys read as 0.
    """
    acres_by_county_species = species_df.set_index(['GEOID', 'EntityID'])['Area_in_Acres']
    lookup = defaultdict(int)
    lookup.update(acres_by_county_species.to_dict())
    return lookup

def generate_national_species_acres_dict(species_df):
    """Return a plain dict of total species acres keyed by ``EntityID``.

    Acres are summed over every row of ``species_df`` that shares an EntityID.
    """
    acres_per_entity = species_df.groupby("EntityID")["Area_in_Acres"].sum()
    return acres_per_entity.to_dict()
    

def generate_crop_dict(crop_df):
    """Return county-level crop values keyed by ``(GEOID, CONCAT USE SITE)``.

    The result is a ``defaultdict(int)``, so unknown keys read as 0.
    """
    value_by_county_crop = crop_df.set_index(['GEOID', 'CONCAT USE SITE'])['VALUE']
    lookup = defaultdict(int)
    lookup.update(value_by_county_crop.to_dict())
    return lookup

def generate_state_lookup(species_df):
    """Return a plain dict mapping each ``GEOID`` to its ``State`` value."""
    state_by_geoid = species_df.set_index(['GEOID'])['State']
    return state_by_geoid.to_dict()

def generate_max_crop_area_by_state_lookup(state_crop_df):
    """Return state-level crop values keyed by ``(Location, CONCAT USE SITE)``.

    The result is a ``defaultdict(int)``, so unknown keys read as 0.
    """
    value_by_state_crop = state_crop_df.set_index(['Location', 'CONCAT USE SITE'])['Value']
    lookup = defaultdict(int)
    lookup.update(value_by_state_crop.to_dict())
    return lookup

def generate_national_max_crop_area_lookup(state_crop_df):
    """Return crop values keyed by ``CONCAT USE SITE``.

    Reads the ``VALUE`` column of the table passed in. The result is a
    ``defaultdict(int)``, so unknown crops read as 0.
    """
    value_by_crop = state_crop_df.set_index(['CONCAT USE SITE'])['VALUE']
    lookup = defaultdict(int)
    lookup.update(value_by_crop.to_dict())
    return lookup

def preprocess_county_level_crop_df(county_crop_df):
    """Cast the key columns of the county crop table, in place.

    ``GEOID`` is cast to int and ``CONCAT USE SITE`` to str. The same frame
    object that was passed in is returned.
    """
    # Target type for each column used for indexing
    key_column_types = {'GEOID': 'int', 'CONCAT USE SITE': 'str'}
    for column, target_type in key_column_types.items():
        county_crop_df[column] = county_crop_df[column].astype(target_type)

    return county_crop_df

def filter_crops(crop_df,included_crops):
    """Return the rows of ``crop_df`` whose ``CONCAT USE SITE`` is in ``included_crops``."""
    is_included = crop_df['CONCAT USE SITE'].isin(included_crops)
    return crop_df[is_included]


def apply_county_species_area_adjustment(county_crop_area_df, unique_crops, species_dict):
    """Scale crop areas down where their row total exceeds the species area.

    For each row the adjustment factor is the sum of the crop columns divided
    by the species area looked up under ``(GEOID, Entity_ID)``. Factors above
    1 divide every crop in the row, so the crops shrink proportionally;
    anything else leaves the row unchanged.

    Parameters
    ----------
    county_crop_area_df :
        Frame with ``GEOID``, ``Entity_ID`` and one column per crop.
        Modified in place and returned.
    unique_crops :
        List of crop column names.
    species_dict :
        Mapping of ``(GEOID, Entity_ID)`` to species area.

    Returns
    -------
    The same frame object, with adjusted crop columns.
    """
    # Species area for every row, looked up column-wise instead of row-by-row
    row_keys = zip(county_crop_area_df["GEOID"], county_crop_area_df["Entity_ID"])
    species_area = pd.Series([species_dict[key] for key in row_keys],
                             index=county_crop_area_df.index, dtype=float)

    # Calculate the adjustment factor (sum of crop area divided by species area)
    adjustment = county_crop_area_df[unique_crops].sum(axis=1) / species_area

    # Only factors above 1 scale anything. An undefined factor (0 crop area
    # over 0 species area gives NaN) is also set to 1, so those rows keep
    # their zeros instead of turning into NaN.
    adjustment = adjustment.where(adjustment > 1, 1.0)

    # Divide all crop areas in each row by the corresponding adjustment factor
    for crop in unique_crops:
        county_crop_area_df[crop] = county_crop_area_df[crop] / adjustment

    return county_crop_area_df

def extract_unique_crop_list(coa_county_crop_df):
    """Return the distinct ``CONCAT USE SITE`` values as a sorted list."""
    crop_names = coa_county_crop_df['CONCAT USE SITE'].drop_duplicates().tolist()
    crop_names.sort()
    return crop_names
    
def generate_county_level_crop_area_table(county_crop_df, species_df, unique_crops):
    """Build one row per (county, species) pair with that county's crop values.

    Rows follow the key order of the species lookup. Each crop column holds
    the county crop value for the row's GEOID, or 0 where the crop lookup has
    no entry for that county and crop.

    Returns
    -------
    DataFrame with columns ``GEOID``, ``Entity_ID`` and one per crop.
    """
    # Lookups keyed by (GEOID, EntityID) and (GEOID, crop)
    county_species_keys = generate_species_dict(species_df)
    county_crop_values = generate_crop_dict(county_crop_df)

    # One row per species/county pair: keys first, then each crop's value
    rows = [
        [geoid, entity_id] + [county_crop_values[(geoid, crop)] for crop in unique_crops]
        for geoid, entity_id in county_species_keys
    ]

    return pd.DataFrame(rows, columns=['GEOID', 'Entity_ID'] + unique_crops)

def apply_state_crop_area_cap(rollup_state_crop_area_df, coa_state_crop_area_df, unique_crops):
    """Cap each rolled-up crop value at the state-level value for that crop.

    The cap for a state/crop pair comes from the state crop lookup, which
    reads as 0 for pairs it does not contain. The rollup frame is modified in
    place and returned.
    """
    # (state, crop) -> maximum value
    state_crop_caps = generate_max_crop_area_by_state_lookup(coa_state_crop_area_df)

    # Visit each state once, then cap every crop column for that state's rows
    for state in rollup_state_crop_area_df['State'].unique():
        in_state = rollup_state_crop_area_df["State"] == state
        for crop in unique_crops:
            ceiling = state_crop_caps[state, crop]
            current = rollup_state_crop_area_df.loc[in_state, crop]
            rollup_state_crop_area_df.loc[in_state, crop] = current.apply(
                lambda acres: min(acres, ceiling))

    return rollup_state_crop_area_df

def roll_up_county_area_table_to_state(county_crop_area_df, species_df):
    """Sum a county-level table to one row per (State, Entity_ID).

    Parameters
    ----------
    county_crop_area_df :
        Frame with ``GEOID`` and ``Entity_ID`` columns plus the value columns
        to sum. Not modified.
    species_df :
        Source of the GEOID -> State lookup.

    Returns
    -------
    DataFrame with ``State`` and ``Entity_ID`` followed by the summed columns.

    Raises
    ------
    KeyError
        If a GEOID in the table has no entry in the state lookup.
    """
    # Get GEOID to state dict
    geoid_to_state_dict = generate_state_lookup(species_df)

    # State for every row; a plain lookup so an unknown GEOID fails loudly
    state_names = [geoid_to_state_dict[geoid] for geoid in county_crop_area_df['GEOID']]

    # Swap GEOID for State on a new frame, leaving the caller's table untouched
    with_state_df = county_crop_area_df.drop(columns=['GEOID']).assign(State=state_names)

    # Aggregate by state and sum
    rollup_state_crop_area_df = with_state_df.groupby(by = ['State','Entity_ID']).sum().reset_index()

    return rollup_state_crop_area_df

def apply_national_crop_area_cap(rollup_national_crop_area_df, coa_national_crop_area_df, unique_crops):
    """Cap each rolled-up crop column at the national value for that crop.

    The cap comes from the national crop lookup, which reads as 0 for crops
    it does not contain. The rollup frame is modified in place and returned.
    """
    # crop -> maximum value
    national_crop_caps = generate_national_max_crop_area_lookup(coa_national_crop_area_df)

    for crop in unique_crops:
        ceiling = national_crop_caps[crop]
        capped = rollup_national_crop_area_df[crop].apply(lambda acres: min(acres, ceiling))
        rollup_national_crop_area_df[crop] = capped

    return rollup_national_crop_area_df

def roll_up_state_area_table_to_national(rollup_state_crop_area_df, species_df):
    """Sum a state-level table to one row per ``Entity_ID``.

    Parameters
    ----------
    rollup_state_crop_area_df :
        Frame with ``State`` and ``Entity_ID`` columns plus the value columns
        to sum. Not modified.
    species_df :
        Not used; kept so existing callers keep working.

    Returns
    -------
    DataFrame with ``Entity_ID`` followed by the summed columns.
    """
    # Aggregate by Entity ID once the State column is out of the way
    without_state_df = rollup_state_crop_area_df.drop(['State'],axis=1)
    rollup_national_crop_area_df = without_state_df.groupby(by = ['Entity_ID']).sum().reset_index()

    return rollup_national_crop_area_df

def get_national_overlap_table(rollup_national_crop_area_df,species_df):
    """Convert national crop areas to percentages of national species area.

    National species area is the sum of the county species lookup over all
    counties for each EntityID. Every column other than ``Entity_ID`` is
    replaced by ``value / species area * 100``. The input table is not
    modified; a converted copy is returned.
    """
    # National species acreage, accumulated from the county-level lookup
    county_species_acres = generate_species_dict(species_df)
    national_species_acres = defaultdict(int)
    for (_geoid, entity_id), acres in county_species_acres.items():
        national_species_acres[entity_id] += acres

    # Everything except the key column is treated as a crop column
    crop_columns = set(rollup_national_crop_area_df.columns)
    crop_columns.remove('Entity_ID')

    def to_percent(row, crop):
        return row[crop] / national_species_acres[row['Entity_ID']] * 100

    # Convert on a copy so the input table is left as it was
    pct_df = rollup_national_crop_area_df.copy()
    for crop in crop_columns:
        pct_df[crop] = pct_df.apply(to_percent, axis=1, args=(crop,))

    return pct_df


def generate_imputation_flag_columns(do_output_df, county_crop_area_df, species_df):
    """Add a "<crop>_Flags" column per crop to the direct overlap table.

    A flag is listed for a species and crop when at least one GEOID carrying
    that crop/``Imputation`` combination is also a GEOID of the species.
    Flags are joined with ", ". The flag columns are added to
    ``do_output_df`` in place; the returned frame has the non-crop columns
    first, followed by the "<crop>_0" and "<crop>_Flags" columns sorted by
    name.
    """
    crop_names = county_crop_area_df["CONCAT USE SITE"].unique().tolist()

    # GEOIDs for every (crop, flag) combination
    geoids_by_crop_flag = {
        crop_flag: set(rows['GEOID'].unique())
        for crop_flag, rows in county_crop_area_df.groupby(by = ["CONCAT USE SITE", "Imputation"])
    }

    # GEOIDs for every species
    geoids_by_species = {
        entity_id: set(rows['GEOID'].unique())
        for entity_id, rows in species_df.groupby(by = "EntityID")
    }

    # Flags whose counties overlap the species' counties, per (species, crop)
    flags_by_species_crop = defaultdict(list)
    for (crop, flag), crop_geoids in geoids_by_crop_flag.items():
        for entity_id, species_geoids in geoids_by_species.items():
            if crop_geoids & species_geoids:
                flags_by_species_crop[(str(entity_id), crop)].append(flag)

    # Crop-related columns: the "_0" overlap columns plus the new flag columns
    crop_related_cols = [crop + "_0" for crop in crop_names]
    for crop in crop_names:
        flag_col = str(crop) + "_Flags"
        crop_related_cols.append(flag_col)
        do_output_df[flag_col] = do_output_df['EntityID'].apply(
            lambda entity_id: ", ".join(flags_by_species_crop[(entity_id, crop)]))

    # Species info columns first, then the crop columns in sorted order
    leading_cols = [col for col in do_output_df.columns if col not in crop_related_cols]

    return do_output_df[leading_cols + sorted(crop_related_cols)]
    

def import_coa_county_df(file_location):
    """Read the county-level crop CSV, keeping only the typed columns below."""
    # Column to import -> type it is read as
    column_types = {
        "CONCAT USE SITE": str,
        "STATE_NAME": str,
        "VALUE": float,
        "GEOID": int,
        "Imputation": str,
    }

    return pd.read_csv(file_location, usecols=list(column_types), dtype=column_types)
    
def import_coa_state_df(file_location):
    """Read the state-level crop CSV, keeping only the typed columns below."""
    # Column to import -> type it is read as
    column_types = {
        "Location": str,
        "CONCAT USE SITE": str,
        "Value": float,
        "Imputation": str,
    }

    return pd.read_csv(file_location, usecols=list(column_types), dtype=column_types)

def import_coa_national_df(file_location):
    """Read the national-level crop CSV, keeping only the typed columns below."""
    # Column to import -> type it is read as
    column_types = {
        "CONCAT USE SITE": str,
        "VALUE": float,
    }

    return pd.read_csv(file_location, usecols=list(column_types), dtype=column_types)

def is_float(element):
    """Return True if ``float(element)`` succeeds, False otherwise.

    Values that ``float`` rejects by type (such as ``None``) or that are too
    large to convert count as "not a float" instead of raising.
    """
    try:
        float(element)
    except (TypeError, ValueError, OverflowError):
        return False
    return True

def import_species_df(file_location):
    """Read the species CSV and normalise numeric-looking EntityIDs.

    Only the columns listed below are read. ``EntityID`` is read as text;
    any value that parses as a finite number is rewritten as the string of
    its integer part (for example "12.0" becomes "12"). All other values,
    including missing ones, are left as they were read.
    """
    from math import isfinite

    # Column to import -> type it is read as
    dtypes = {"EntityID": str,
              "STATEFP": int,
              "GEOID": int,
              "NAME_1": str,
              "State": str,
              "Area_in_m2": float,
              "Area_in_Acres": float}

    # Read in the species file
    df = pd.read_csv(file_location,usecols=list(dtypes),dtype=dtypes)

    def _normalise_entity_id(value):
        # Non-numeric IDs pass through untouched
        try:
            number = float(value)
        except (TypeError, ValueError):
            return value
        # NaN and infinity parse as floats but have no integer form
        if not isfinite(number):
            return value
        return str(int(number))

    # Ensure EntityID is imported correctly (single pass over the column)
    df["EntityID"] = df["EntityID"].map(_normalise_entity_id)

    return df